Skip to content

TE EP integration to MoEBlock#3116

Open
tdophung wants to merge 63 commits into
NVIDIA:mainfrom
tdophung:teddy/te_ep_integration
Open

TE EP integration to MoEBlock#3116
tdophung wants to merge 63 commits into
NVIDIA:mainfrom
tdophung:teddy/te_ep_integration

Conversation

@tdophung

@tdophung tdophung commented Jun 10, 2026

Copy link
Copy Markdown
Collaborator

Description

Integrate the new TE EP dispatch + combine APIs in the MoEBlock, in replacement for the previous ragged-all-to-all communications and Triton permutation kernels. Router CUDA kernels + grouped GEMM is the same as before.

Will rebase and squash the commits on this branch once about to merge
Will also change the JAX APIs if needed when TE EP JAX merge

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Remove previous multiprocess test script for MoE VJP in general
  • Add new multiprocess test script for MoEBlock x TE EP
  • Integrate TE EP dispatch + combine APIs into MoeBlock

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng added 2 commits June 9, 2026 18:27
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@tdophung tdophung force-pushed the teddy/te_ep_integration branch from 0ff3bff to bd14fe6 Compare June 10, 2026 21:58
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
# Minimum per-expert slot alignment fed to ``tex.ep_prepare``. Default 0
# uses the natural slot count; set to e.g. 128 to satisfy FP8 grouped-GEMM
# tile alignment.
align_size: int = 0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user

Comment thread transformer_engine/jax/flax/moe.py Outdated
nn.with_logical_partitioning(self.bias_init, ("exp",)),
(self.num_experts,),
self.dtype,
jnp.float32,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the router always in fp32 so this expert bias must also be? If so, can we add a small comment indicating this

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I will add a comment



__all__ = ["moe", "PermutationBackend"]
def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this utility function? I haven't seen something like this required for our other VJPs

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for MoE particularly, In our _moe_bwd_rule, d_x is built from two cotangent paths:

d_x_from_dispatch = tex.ep_dispatch_bwd(...) in bf16 if x is in bf16.
and
d_x_from_gate = d_logits_2d @ gate_kernel^T, where d_logits_2d comes from tex.fused_topk_with_score_function_bwd, which only ooutput the logits in fp32 (per alp in one of the weekly meetings + I checked the code for the router kernel). So d_x_from_gate is fp32.

Therefore, d_x = d_x_from_dispatch + d_x_from_gate = bf16 + fp32 = fp32, while our x could be bf16. So the cotangent that flows to bwd needs to be constrainted. Other VJPs don't have this issue because the dgrad will just be the same dtype as the activation dtype.

# is a frozen dataclass of ints); the rest are jnp.ndarray,
# GroupedNoScaleTensor (already a pytree), or None when aux_loss_coeff == 0.
@register_pytree_node_class
@dataclass

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this tree_flatten was from my patch, but looking at the diff I think it'd be better to use the @flax_struct.dataclass you were using on the permutation dataclasses since that seems to auto-populate a default pytree flatten/unflatten for us

Comment thread transformer_engine/jax/moe.py Outdated
else:
d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat)

# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, after confirming with Tim and Przemek, we should kjeep activation input in BF16 and activation output in GEMM's dtype. In latest commit, this should be BF16

# local expert. We must size to that worst case or NCCL EP's HT kernel
# rejects the dispatch buffer with ``invalid argument``.
natural_spe = num_ep * max_tokens_per_rank # = (B // dp_size) * S
# NCCL EP requires each expert-major output block to be at least

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a use-case for user-specified alignments beyond 128 currently? If NCCL EP requires an alignment of at least 128, and since an alignment of 128 is sufficient for all TE grouped GEMM types, would it make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API.

We can always expand the API to support a user-specified align size in the future

batch_pspec_axis = (*data_parallelism_axes, ep_axis)
ep3_spec = P(batch_pspec_axis, None, None)
ep2_spec = P(batch_pspec_axis, None)
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which axis name inputs are physical mesh axes and why can be logical axes? I see above x = with_sharding_constraint_by_logical_axes(x, input_axes) but here we directly use jax.lax.with_sharding_constraint which only supports mesh axes.

No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes. Thanks!

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added some comments

Comment thread transformer_engine/jax/moe.py Outdated
# `grad_pre_combine * w` sees them. Padded positions in sparse_probs
# are already zero (routing_map is False there); only the rare
# underflow path emits NaN.
sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this NaN filtering a debugging artifact or something we need in the final version?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

debugging artifact. Remopving

tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…IA#3116)

Address jberchtold-nvidia's PR NVIDIA#3116 nit "rename use_bias ->
use_ffn_bias and use_expert_bias -> use_expert_routing_bias". The
two flags are siblings (they enable two different bias buffers) but
the old names suggested ``use_bias`` was the general fallback, which
wasn't the intent. The new names make the FFN-vs-routing distinction
obvious from the call site.

* transformer_engine/jax/flax/moe.py
    use_bias -> use_ffn_bias  (dataclass field + branch in __call__
    + docstring entry)
    use_expert_bias -> use_expert_routing_bias  (same)
* tests/jax/test_te_ep_moe.py
    _make_block(use_expert_bias=...) -> use_expert_routing_bias
    sigmoid-bias-strong config key updated
    _reference_kwargs_from_config now reads use_expert_routing_bias

``_MoEBlock`` is still the experimental underscore-prefixed alias
(no public ``MoEBlock`` export yet), so the rename is API-safe.

The pre-resync legacy tests (``test_moe_vjp.py``,
``test_multiprocess_moe_vjp.py``) are intentionally not updated --
they already reference removed APIs like ``PermutationBackend`` and
need a separate post-resync cleanup pass.

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…and inline justifications)

Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on
``transformer_engine/jax/moe.py``. All changes are confined to a
single file because each review thread targets a localized region
and splitting mid-file would risk reordering bugs.

Per review thread:

1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't
    seen something like this required for our other VJPs."
   -- Expand the helper's docstring to spell out exactly why MoE
   needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent
   from ep_dispatch_bwd with an fp32 cotangent from
   fused_topk_with_score_function_bwd (which the fwd's
   logits_2d -> fp32 promotion forces). Without the cast, ``d_x``
   surfaces at fp32 even when ``x`` is bf16, doubling activation
   grad bandwidth and breaking any downstream LN bwd that pins a
   bf16 layout. (Review thread "Why do we need this utility
   function?".)

2. "Why is this dtype casting required? I don't recall us needing
    it for the non-MoE LNMLP block."
   -- Expand the comment above the bwd activation fp32 promotion
   to explain the MoE-specific math: LN+MLP's silu sits behind a
   downstream LN that absorbs the bf16 rounding error, while
   MoE's silu sits on the *expert* side of routing -- the bf16
   rounding rides directly into expert_outputs and is summed
   across topk experts by ep_combine. Bf16 silu alone drifts ~1%
   vs fp32 silu and compounds through wo->combine into the ~1.4%
   per-element parity gap we measured against the pure-JAX
   softmax reference. Mirroring the fwd's fp32 promotion in the
   bwd keeps silu' in lock-step with silu. (Review thread on
   "# Activation bwd. Mirror the fwd's fp32 promotion of
   silu+multiply".)

3. "Do we have a use-case for user-specified alignments beyond
    128 currently? ... it'd make sense to instead hardcode
    _ALIGN_SIZE = 128 as a constant at the top of the file for
    now to simplify this MoEBlock API. We can always expand the
    API to support a user-specified align size in the future."
   -- Implement the suggestion. Drop ``align_size`` from
   ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public
   ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from
   ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align
   = max(int(align_size), 128)`` with the new module-level
   ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring
   accordingly. (Review thread on
   "natural_spe = num_ep * max_tokens_per_rank".)

4. "Which axis name inputs are physical mesh axes and why can be
    logical axes? ... No need to make any changes for now, I just
    want to assess which are which and then we can discuss if it
    makes sense to support logical on some/all or if some are
    required to be physical axes."
   -- Add an "Axis-name parameters" section to ``moe()``'s
   docstring listing which kwargs are physical mesh axes
   (``ep_axis``, ``data_parallelism_axes`` -- they index
   ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size``
   and to construct the ``P((dp..., ep), None, None)`` for
   ``jax.lax.with_sharding_constraint``) vs logical axes
   (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``,
   ``wo_kernel_axes`` -- resolved via the Flax logical-axis
   rules). Also document why ``ep_axis`` / ``data_parallelism_axes``
   are intentionally non-logical: the EP comm-group construction
   (``dp_color = rank // ep_size``) and the bootstrap signature
   check both require concrete integer sizes. (Review thread on
   "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".)

5. "Is this NaN filtering a debugging artifact or something we
    need in the final version?"
   -- Strengthen the inline comment above
   ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)``
   to explicitly call this out as a CORRECTNESS REQUIREMENT, not
   a debugging artifact: it covers the sigmoid+K>1 underflow
   path where top-K sigmoid scores all round to zero and the
   ``weights / (weights.sum + 1e-20)`` normalisation emits NaN.
   Observationally the filter is a no-op on the dense unit-test
   distributions, but it must stay in for sparse / production
   routing. (Review thread on
   "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).")

Not addressed in this commit (intentional):

* Review thread on the ``align_size: int = 0`` placeholder in
  ``flax/moe.py`` ("Placeholder comment for me to fix this so
  align_size is inferred automatically based on the recipe and
  doesn't need to be specified by the user"). That's
  jberchtold's own follow-up.
* Review thread on the explicit ``tree_flatten`` /
  ``tree_unflatten`` on ``_Ctx`` ("better to use the
  ``@flax_struct.dataclass``"). Deferred to a separate, testable
  commit because changing a ``custom_vjp`` residual's pytree
  registration touches subtle ordering / None-handling semantics
  that warrant their own bisect surface.
* Review thread on ``use_bias`` / ``use_expert_bias`` renames --
  handled in the immediately preceding commit
  ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``.
* Review thread on the ``expert_bias`` fp32 init -- already
  resolved during the Phuong PR NVIDIA#3036 resync (the redundant
  ``jnp.float32`` second-dtype argument on ``self.param`` was
  dropped; ``expert_bias`` now lives at ``self.dtype``).

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…aN sanitizer

* Rewrite the inline justifications added in 078a7d80 so each one
  reads as standalone code documentation, not as a reply to a
  reviewer: drop "per PR NVIDIA#3116 review", "review feedback",
  "Renamed from ... per PR ..." and similar PR/thread references
  from moe.py, flax/moe.py, and tests/jax/test_te_ep_moe.py.
  Technical content (why the fp32 promotion is needed for the MoE
  silu+multiply, why _with_sharding_constraint_cast_bwd exists,
  physical-vs-logical axis split in moe() docstring, the 128
  alignment rationale) is preserved and reframed to be useful to
  a reader who has no PR context.

* Drop the jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs)
  guard. Tracing fused_topk_with_score_function.cu shows the
  kernel divides by sum_scores + 1e-20, so finite non-negative
  sigmoid scores cannot produce NaN here; the filter was only
  defense against upstream NaNs, which would mask a real
  regression if anything ever did start producing them.

Signed-off-by: tdophung <tdophung@nvidia.com>
pre-commit-ci Bot and others added 23 commits June 11, 2026 17:15
for more information, see https://pre-commit.ci

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rce at dispatch

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…er + NVTEEpHandle struct (NVTE_EP_HANDLE_CACHE_SIZE=-1 disables eviction)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…hout it

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ogging.h

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…_COPY_{ON,OFF}

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tyAllSymm

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CUDA Toolkit)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… for wheel install

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…bmodules

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rop submodule header mirror

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…al CommWindow

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
phu0ngng and others added 23 commits June 11, 2026 17:15
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with
EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied
the three deltas uniquely ours:

  * transformer_engine/jax/moe.py: replaces upstream's multi-backend
    MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted
    to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle
    (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed
    in place of handle, ep_prepare arg order swapped, top_k= dropped
    from ep_dispatch_bwd since it's now in cfg.
  * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with
    ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped
    (no longer supported; ep_size is derived from mesh axes and the
    handle_mem reloc gating is gone).
  * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch
teddy/te_ep_integration.backup-pre-phuong-sync.
EOF
)

Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: tdophung <tdophung@nvidia.com>
…ero)

* drop ``TestZZZTeEpMoeBootstrap``: the re-bootstrap mismatch is a
  one-line guard in ``ep_bootstrap`` and not the MoE block's concern;
  exercising it from this suite also taints the per-process NCCL
  bootstrap cache for the rest of the file with no real upside.
* drop ``TestTeEpMoEBlockFlax::test_init_apply_parity``: every config
  in ``_CONFIGS`` already runs ``MoEBlock`` (the Flax wrapper)
  end-to-end via ``test_forward`` / ``test_backward``, so this was a
  duplicate of ``softmax`` parity in another wrapper -- leave wrapper
  refactors to devs without paying for an extra CI run each time.
* drop ``sigmoid-bias-zero``: with a zero-init bias buffer the routing
  math collapses to the no-bias case, so ``sigmoid`` already covers
  that numerical path. The bias-aware codepath is still exercised by
  ``sigmoid-bias-strong`` (non-zero bias).
* refresh the module-level docstring to list intentional
  non-coverage so future readers don't re-add these tests.

Signed-off-by: tdophung <tdophung@nvidia.com>
… closure)

Two unrelated one-line bugs in the bwd custom_partitioning machinery
that only surface once the MoE block's aux-loss path is lifted out of
shard_map (the custom_partitioning_sharding_rule check is skipped under
shard_map, which is why these never tripped before).

1. FusedMoEAuxLossBwdPrimitive.shardy_sharding_rule:
   ``grad_aux_loss`` is the cotangent of a scalar loss and is rank-0;
   declaring it with a spurious ``grad_one`` factor gave it rank-1 and
   tripped JAX's custom_partitioning_sharding_rule rank check at global
   view. Change the rule's third operand entry to empty:

     "const_buf_one, num_experts, grad_one -> i num_experts"
   ->
     "const_buf_one, num_experts, -> i num_experts"

2. FusedTopkWithScoreFunctionBwdPrimitive.partition:
   ``del result_infos, routing_map_format`` removed
   ``routing_map_format`` from the enclosing scope before the nested
   ``sharded_impl`` closure was invoked. Python closures resolve names
   at call time, not definition time, so when XLA finally invoked
   ``sharded_impl`` for the bwd partitioned impl it raised
   ``NameError: cannot access free variable 'routing_map_format'``.
   Drop ``routing_map_format`` from the ``del`` and leave a NOTE so
   future cleanups don't reintroduce the bug. Sibling partition
   methods (fwd topk, both aux-loss directions) already only
   ``del result_infos`` and need no change.

Signed-off-by: tdophung <tdophung@nvidia.com>
A dp_resource or fsdp_resource that exists in the active mesh resource
config but is sized 1 in the actual mesh would still be returned by
``_ep_outer_axis()``, pinning EP-output PartitionSpecs to a degenerate
axis. JAX collapses size-1 mesh axes during lowering, which made the
EP-output specs reference an axis that no longer exists at runtime --
breaking shard_map output stitching on configs where DP or FSDP is
optional.

Treat a size-1 axis as absent: prefer dp -> fsdp, but only when the
candidate axis is actually sized > 1 in the current mesh. Falls back
to the previous behaviour when no axis is configured at all.

Signed-off-by: tdophung <tdophung@nvidia.com>
After the upstream PR NVIDIA#3036 resync the moe() API surface lost
PermutationBackend (TE-EP is the only backend now), gate_inside_vjp
(always True), and the per-call quantizer_sets knob (quantization
flows through the standard TE autocast / with_quantizer_set context).
It also gained apply_topk_weights_early and renamed the wrapper's
private _align_size to the public align_size the test suite already
uses. The Flax _MoEBlock wrapper was still passing the old kwargs,
which broke every test that touched the wrapper.

Wrapper changes:
  * drop "from ..moe import PermutationBackend" plus the dataclass
    field, the isinstance(..., PermutationBackend) validation in
    __post_init__, and the pass-through to moe().
  * drop "from ..quantize import noop_quantizer_set" and the
    quantizer_sets=(noop, noop, noop) pass-through.
  * drop gate_inside_vjp=True.
  * rename _align_size: int = 0 -> align_size: int = 0 (matches
    what tests/jax/test_te_ep_moe.py already passes).
  * add apply_topk_weights_early: bool = False and pass it through
    to moe().
  * refresh class docstring: drop permutation_backend / _align_size
    / quantizer_sets descriptions, add apply_topk_weights_early /
    align_size, note that quantization currently flows only through
    fp8_autocast.

Signed-off-by: tdophung <tdophung@nvidia.com>
…ices

Two correctness fixes for the TE-EP MoE custom_vjp that together let
the bwd parity tests pass on 0-token-globally experts, and drop a
workaround that is no longer needed.

(1) Plumb per-expert padded token_counts into grouped_gemm group_sizes.

NCCL EP HT dispatch lays out recv_tokens expert-major as
  [expert_0_padded | expert_1_padded | ... | overalloc_tail]
where each per-expert block already includes the
dispatch_output_per_expert_alignment zero-padding and only the trailing
overalloc tail (slack between sum(token_counts) and the worst-case
recv_pr) is unused. Previously _ffn_fwd_per_shard built a static
local_group_sizes = jnp.full((num_local_experts,), slots_per_expert),
which over-counted by the overalloc tail and forced cuBLAS to run the
GEMM for every group including 0-token-routed experts.

Pipe the real per-shard token_counts (1, num_local_experts) from
ep_prepare through _moe_fwd_rule (added to ffn_in_specs/ffn_in_args
with ep2_spec), into _ffn_fwd_per_shard as token_counts_local, and
reshape into local_group_sizes for both grouped_quantize and
grouped_gemm. cuBLAS now skips both 0-token experts and the trailing
overalloc tail. Mirror the residual spec change on the bwd
(local_group_sizes residual moves from P() to ep2_spec).

(2) Per-group jnp.where zero-fill on wgrad outputs.

cuBLAS grouped_gemm skips groups with size_g == 0 without zero-filling
the corresponding out[g, :, :] slice (cublaslt_grouped_gemm.cu lines
2086/2096). For a shard hosting an expert that received zero tokens
globally, d_wo / d_wi_combined for that expert is left uninit, which
propagates NaN straight into the user's optimizer state.

Add wgrad_group_active = (local_group_sizes > 0)[:, None, None] in
_ffn_bwd_per_shard and apply via jnp.where on d_wo (right after the wo
wgrad) and d_wi_combined (right after the fused wi_0+wi_1 wgrad).
Mask shape is (num_local_experts, 1, 1) so cost is negligible.

(3) Drop the lax.cond zero-init guard on r_tok in _moe_fwd_rule._body.

Previously a jax.lax.cond(jnp.any(r_w != 0), identity, zeros_like)
wrapper around recv_tokens worked around tex.ep_dispatch_fwd leaving
the recv buffer uninit on fully-empty-receiver ranks. With (1) in
place, cuBLAS skips experts whose group_sizes == 0 and the per-row
trailing tail of dispatched recv_tokens is unread by every downstream
consumer (subsequent grouped_gemms read only sum(group_sizes) rows;
ep_combine and ep_dispatch_bwd are handle_mem-aware). The only
per-row consumer that would propagate the tail is grouped_dbias
(per-row segment_sum), which only runs when has_bias=True, and that
FFN bias path is currently gated upstream (cuBLAS grouped_gemm has
no fused bias on Hopper yet; PR 3083 adds the pure-JAX bias add).
With (2) handling the user-visible wgrad-NaN risk on 0-token experts,
the lax.cond is now redundant. Replace with a NOTE pointing at the
two follow-ups that would force its reintroduction:
  - a future caller that reads the full recv tile non-group-aware
    (e.g. an inspect probe), or
  - the FFN bias path landing, which would resurrect grouped_dbias.

Also rewrite the _ffn_fwd_per_shard and _ffn_bwd_per_shard docstrings
to spell out the per-row vs per-group uninit semantics so the next
person debugging a NaN here has the invariants written down.

Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: tdophung <tdophung@nvidia.com>
…IA#3116)

Address jberchtold-nvidia's PR NVIDIA#3116 nit "rename use_bias ->
use_ffn_bias and use_expert_bias -> use_expert_routing_bias". The
two flags are siblings (they enable two different bias buffers) but
the old names suggested ``use_bias`` was the general fallback, which
wasn't the intent. The new names make the FFN-vs-routing distinction
obvious from the call site.

* transformer_engine/jax/flax/moe.py
    use_bias -> use_ffn_bias  (dataclass field + branch in __call__
    + docstring entry)
    use_expert_bias -> use_expert_routing_bias  (same)
* tests/jax/test_te_ep_moe.py
    _make_block(use_expert_bias=...) -> use_expert_routing_bias
    sigmoid-bias-strong config key updated
    _reference_kwargs_from_config now reads use_expert_routing_bias

``_MoEBlock`` is still the experimental underscore-prefixed alias
(no public ``MoEBlock`` export yet), so the rename is API-safe.

The pre-resync legacy tests (``test_moe_vjp.py``,
``test_multiprocess_moe_vjp.py``) are intentionally not updated --
they already reference removed APIs like ``PermutationBackend`` and
need a separate post-resync cleanup pass.

Signed-off-by: tdophung <tdophung@nvidia.com>
…and inline justifications)

Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on
``transformer_engine/jax/moe.py``. All changes are confined to a
single file because each review thread targets a localized region
and splitting mid-file would risk reordering bugs.

Per review thread:

1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't
    seen something like this required for our other VJPs."
   -- Expand the helper's docstring to spell out exactly why MoE
   needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent
   from ep_dispatch_bwd with an fp32 cotangent from
   fused_topk_with_score_function_bwd (which the fwd's
   logits_2d -> fp32 promotion forces). Without the cast, ``d_x``
   surfaces at fp32 even when ``x`` is bf16, doubling activation
   grad bandwidth and breaking any downstream LN bwd that pins a
   bf16 layout. (Review thread "Why do we need this utility
   function?".)

2. "Why is this dtype casting required? I don't recall us needing
    it for the non-MoE LNMLP block."
   -- Expand the comment above the bwd activation fp32 promotion
   to explain the MoE-specific math: LN+MLP's silu sits behind a
   downstream LN that absorbs the bf16 rounding error, while
   MoE's silu sits on the *expert* side of routing -- the bf16
   rounding rides directly into expert_outputs and is summed
   across topk experts by ep_combine. Bf16 silu alone drifts ~1%
   vs fp32 silu and compounds through wo->combine into the ~1.4%
   per-element parity gap we measured against the pure-JAX
   softmax reference. Mirroring the fwd's fp32 promotion in the
   bwd keeps silu' in lock-step with silu. (Review thread on
   "# Activation bwd. Mirror the fwd's fp32 promotion of
   silu+multiply".)

3. "Do we have a use-case for user-specified alignments beyond
    128 currently? ... it'd make sense to instead hardcode
    _ALIGN_SIZE = 128 as a constant at the top of the file for
    now to simplify this MoEBlock API. We can always expand the
    API to support a user-specified align size in the future."
   -- Implement the suggestion. Drop ``align_size`` from
   ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public
   ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from
   ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align
   = max(int(align_size), 128)`` with the new module-level
   ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring
   accordingly. (Review thread on
   "natural_spe = num_ep * max_tokens_per_rank".)

4. "Which axis name inputs are physical mesh axes and why can be
    logical axes? ... No need to make any changes for now, I just
    want to assess which are which and then we can discuss if it
    makes sense to support logical on some/all or if some are
    required to be physical axes."
   -- Add an "Axis-name parameters" section to ``moe()``'s
   docstring listing which kwargs are physical mesh axes
   (``ep_axis``, ``data_parallelism_axes`` -- they index
   ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size``
   and to construct the ``P((dp..., ep), None, None)`` for
   ``jax.lax.with_sharding_constraint``) vs logical axes
   (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``,
   ``wo_kernel_axes`` -- resolved via the Flax logical-axis
   rules). Also document why ``ep_axis`` / ``data_parallelism_axes``
   are intentionally non-logical: the EP comm-group construction
   (``dp_color = rank // ep_size``) and the bootstrap signature
   check both require concrete integer sizes. (Review thread on
   "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".)

5. "Is this NaN filtering a debugging artifact or something we
    need in the final version?"
   -- Strengthen the inline comment above
   ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)``
   to explicitly call this out as a CORRECTNESS REQUIREMENT, not
   a debugging artifact: it covers the sigmoid+K>1 underflow
   path where top-K sigmoid scores all round to zero and the
   ``weights / (weights.sum + 1e-20)`` normalisation emits NaN.
   Observationally the filter is a no-op on the dense unit-test
   distributions, but it must stay in for sparse / production
   routing. (Review thread on
   "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).")

Not addressed in this commit (intentional):

* Review thread on the ``align_size: int = 0`` placeholder in
  ``flax/moe.py`` ("Placeholder comment for me to fix this so
  align_size is inferred automatically based on the recipe and
  doesn't need to be specified by the user"). That's
  jberchtold's own follow-up.
* Review thread on the explicit ``tree_flatten`` /
  ``tree_unflatten`` on ``_Ctx`` ("better to use the
  ``@flax_struct.dataclass``"). Deferred to a separate, testable
  commit because changing a ``custom_vjp`` residual's pytree
  registration touches subtle ordering / None-handling semantics
  that warrant their own bisect surface.
* Review thread on ``use_bias`` / ``use_expert_bias`` renames --
  handled in the immediately preceding commit
  ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``.
* Review thread on the ``expert_bias`` fp32 init -- already
  resolved during the Phuong PR NVIDIA#3036 resync (the redundant
  ``jnp.float32`` second-dtype argument on ``self.param`` was
  dropped; ``expert_bias`` now lives at ``self.dtype``).

Signed-off-by: tdophung <tdophung@nvidia.com>
…aN sanitizer

* Rewrite the inline justifications added in 078a7d80 so each one
  reads as standalone code documentation, not as a reply to a
  reviewer: drop "per PR NVIDIA#3116 review", "review feedback",
  "Renamed from ... per PR ..." and similar PR/thread references
  from moe.py, flax/moe.py, and tests/jax/test_te_ep_moe.py.
  Technical content (why the fp32 promotion is needed for the MoE
  silu+multiply, why _with_sharding_constraint_cast_bwd exists,
  physical-vs-logical axis split in moe() docstring, the 128
  alignment rationale) is preserved and reframed to be useful to
  a reader who has no PR context.

* Drop the jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs)
  guard. Tracing fused_topk_with_score_function.cu shows the
  kernel divides by sum_scores + 1e-20, so finite non-negative
  sigmoid scores cannot produce NaN here; the filter was only
  defense against upstream NaNs, which would mask a real
  regression if anything ever did start producing them.

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/te_ep_integration branch from 68617ea to fe44697 Compare June 12, 2026 00:15
tdophung added 2 commits June 12, 2026 11:31
The SwiGLU intermediate (activation inputs gate_proj_out/up_proj_out,
silu+multiply, and activation output) was previously promoted to fp32
in _ffn_fwd_per_shard and again in _ffn_bwd_per_shard, then cast back
to the wi/wo GEMM dtype. The promotion bought nothing: the activation
inputs come out of the wi grouped_gemm in bf16, the activation output
is consumed by the wo GEMM (or wo's quantizer for FP8/FP4) in the same
dtype, and storing higher precision than either consumer is wasted
bandwidth.

* _ffn_fwd_per_shard: drop the .astype(jnp.float32) on gate_proj_out
  and up_proj_out and the trailing .astype(sorted_x.dtype). The
  multiply now stays in the wi GEMM output dtype end-to-end.
* _ffn_bwd_per_shard: symmetric simplification. jax.vjp(act_fn, ...)
  runs at bf16, both d_intermediate * silu' and d_intermediate * up
  stay at bf16, no casts. silu' is now consistent with silu (both
  bf16) so the chain rule composes cleanly without the prior fp32
  detour.
* tests/jax/test_te_ep_moe.py::_pure_jax_moe_reference: drop the
  matching fp32 silu in the parity reference so the test compares
  bf16-vs-bf16. Parity tolerance was not loosened; expect the
  comparison to tighten now that both sides round silu identically.

Also fix an inaccurate inline comment at the apply_topk_weights_early
fwd branch: the bf16 requirement on expert_outputs is enforced by
ep_bootstrap (which rejects max_token_dtype != bf16 and sizes the
NCCL EP HT mega-buffer for 2-byte slots accordingly), not by a
runtime assert in the combine FFI.

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as ready for review June 12, 2026 22:16
@tdophung tdophung requested a review from ptrendx as a code owner June 12, 2026 22:16
@greptile-apps

greptile-apps Bot commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR wires NVIDIA's NCCL-backed Expert Parallelism (EP) primitives into the TransformerEngine JAX MoE block, replacing the previous PURE_JAX/TRITON permutation backends with a single tex.ep_dispatch / tex.ep_combine path under one fused jax.custom_vjp. It adds a new C++ EPBackend singleton, JAX FFI primitives with SPMD custom-partitioning rules, and an eager ep_bootstrap helper that exchanges NCCL UIDs across ranks.

  • New TE EP stack: EPBackend C++ singleton wraps NCCL EP's CreateGroup/UpdateHandle/Dispatch/Combine; the JAX side exposes ep_prepare/ep_dispatch_fwd/ep_combine_fwd and matching _bwd primitives with partition rules and abstract eval.
  • moe.py rewrite: Global-view EP collectives outside shard_map, per-expert GEMMs inside a smaller shard_map; a new _Ctx pytree carries residuals across the fwd→bwd boundary.
  • _MoEBlock API changes: PermutationBackend, align_size, use_bias, use_expert_bias removed or renamed to use_ffn_bias, use_expert_routing_bias, apply_topk_weights_early.

Confidence Score: 3/5

The EP dispatch/combine forward path may silently produce NaN outputs when recv_topk_weights contains NaN at padded slots (a condition the backward explicitly acknowledges); the bootstrap rank-grouping formula is unvalidated against the mesh layout.

Two distinct correctness gaps exist on the hot path: the forward masking step omits NaN sanitization the backward carefully applies to the same tensor, and ep_bootstrap assumes a specific rank ordering never cross-checked against the active JAX mesh.

transformer_engine/jax/moe.py (forward NaN masking ~line 827) and transformer_engine/jax/ep.py (bootstrap rank-layout ~line 110)

Important Files Changed

Filename Overview
transformer_engine/jax/moe.py Major rewrite to TE EP NCCL path; forward missing NaN sanitization of recv_topk_weights before combine masking (backward handles this correctly).
transformer_engine/jax/ep.py New EP bootstrap and custom_vjp wrappers; bootstrap rank-layout assumption unvalidated against JAX mesh.
transformer_engine/jax/cpp_extensions/ep.py New JAX FFI primitives with correct custom_partitioning rules and abstract eval.
transformer_engine/common/ep/ep_backend.cpp NCCL EP singleton backend with LRU handle cache; fallback path silently limits to one layer config per process.
transformer_engine/jax/flax/moe.py Mechanical API renames; consistent and safe.
transformer_engine/jax/csrc/extensions/ep.cpp C++ FFI bridge; handle_mem forwarding and stream plumbing look correct.

Reviews (1): Last reviewed commit: "remove useless comments" | Re-trigger Greptile

Comment on lines +827 to +833
w = recv_topk_weights[..., None].astype(expert_outputs.dtype)
mask_bool = (recv_topk_weights != 0)[..., None]
weighted = jnp.where(mask_bool, expert_outputs * w, jnp.zeros_like(expert_outputs))
output = tex.ep_combine_fwd(
cfg,
handle_mem,
weighted,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 NaN-in-recv_topk_weights not sanitized in forward

The backward (_moe_bwd_rule) explicitly documents that ep_dispatch_fwd can write NaN into recv_topk_weights at padded slots and carefully sanitizes it before forming mask_bool. The forward uses mask_bool = (recv_topk_weights != 0)[..., None] without sanitization: NaN != 0 evaluates to True in IEEE 754, so padded slots with NaN weights are incorrectly treated as active, and expert_outputs * NaN = NaN propagates into the weighted buffer passed to ep_combine_fwd. If the NCCL EP combine kernel ever touches those padded positions, the output will silently contain NaN. The backward already shows the correct fix: replace with recv_w_clean = jnp.where(jnp.isnan(recv_topk_weights), 0, recv_topk_weights) before building w and mask_bool.

Comment on lines +110 to +126
ep_resource = gsr.ep_resource
if ep_resource is None:
raise ValueError(
"ep_bootstrap requires MeshResource.ep_resource to be set; enter a"
" global_shard_guard(MeshResource(..., ep_resource=<axis name>)) before bootstrap."
)
ep_size = get_mesh_axis_size(ep_resource)
outer_axis = _ep_outer_axis()
if outer_axis is None:
if world_size != ep_size:
raise ValueError(
f"ep_bootstrap: world_size ({world_size}) > ep_size ({ep_size}) but neither"
" MeshResource.dp_resource nor fsdp_resource is set; name the outer axis so"
" EP-output tensors can shard across EP groups."
)
num_ep_groups = 1
else:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Rank-layout assumption in ep_bootstrap is unvalidated

dp_color = rank // ep_size hard-codes a specific rank-to-group mapping: EP ranks occupy contiguous blocks of ep_size within the global rank space. This is only correct when the JAX mesh assigns its ep_resource axis as the innermost (fastest-varying) dimension. A mesh constructed as Mesh(devices.reshape(ep, dp), ("ep", "dp")) assigns rank 1 to ep=0, dp=1, yet dp_color = 1 // ep_size = 0 incorrectly merges rank 0 and rank 1 into one communicator. There is no guard that checks mesh.device_ids ordering against the formula, so a silently wrong NCCL communicator is built.

Comment on lines +183 to +199
hidden_dim=hidden_dim,
)
)


def _default_out_partition_spec():
"""Leading-axis default: ``(("dp","ep"),)`` if DP/FSDP is set, else ``("ep",)``."""
gsr = global_mesh_resource()
if gsr.ep_resource is None:
raise ValueError(
"ep_resource is not set on the active MeshResource; pass out_sharding=... explicitly."
)
outer = _ep_outer_axis()
leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource
return (leading,)


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _dispatch_bwd silently ignores cotangents for handle_mem and token_counts

ep_dispatch returns 4 primals, so g_outputs has 4 elements. Only g_outputs[0] and g_outputs[1] are consumed; the cotangents for handle_mem (uint8) and token_counts (int32) are silently dropped. JAX sets these to zero for non-float outputs so behaviour is correct today, but an explicit assertion on len(g_outputs) would protect against future arity changes.

Comment on lines +229 to +248
// ---------------------------------------------------------------------------

size_t EPBackend::cache_cap_locked() {
if (handle_cache_cap_ == 0) {
const char* cap_env = std::getenv("NVTE_EP_HANDLE_CACHE_SIZE");
if (cap_env != nullptr) {
const int64_t v = static_cast<int64_t>(std::atol(cap_env));
if (v < 0) {
// Unlimited cache. WAR for JAX until XLA fixes handle_mem
// reloc between runs.
handle_cache_cap_ = SIZE_MAX;
} else {
NVTE_CHECK(v > 0,
"NVTE_EP_HANDLE_CACHE_SIZE=0 is invalid; use -1 for unlimited or a positive "
"cap.");
handle_cache_cap_ = static_cast<size_t>(v);
}
} else {
handle_cache_cap_ = 4096;
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 fallback_layer_cfg_ limits the backend to one unique layer config per process

The fallback path reconstructs handles using a process-wide cached config, implicitly requiring all MoE layers in the process to share the same (top_k, alignment) pair. Users stacking layers with different top_k values will hit a NVTE_CHECK at runtime without a clear message pointing to this constraint. Consider documenting this in ep.h near NVTEEpLayerConfig.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants